In [533]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from math import log
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn import preprocessing
from sklearn import metrics
import warnings

# Render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib.pyplot as plt

#defaults: large square figures, small font, a little extra tick padding
plt.rcParams['figure.figsize'] = (20.0, 20.0)
plt.rcParams.update({'font.size': 10})
plt.rcParams['xtick.major.pad']='5'
plt.rcParams['ytick.major.pad']='5'

plt.style.use('ggplot')

# Silence noisy library warnings for the demo.
# NOTE(review): `simplefilter`'s category normally takes a single Warning
# subclass; the tuple only works because the filter match uses issubclass,
# which accepts a tuple — TODO confirm, or register one filter per category.
warnings.filterwarnings("ignore", category=FutureWarning )
warnings.simplefilter(action='ignore', category=(UserWarning,RuntimeWarning,FutureWarning))
# Display wide/long frames in full; disable chained-assignment warnings.
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)
pd.options.mode.chained_assignment = None

To demonstrate open-source interpretable ML packages, I have used the UCI credit card default dataset.

UCI credit card default data: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients

The UCI credit card default data contains demographic and payment information about credit card customers in Taiwan in the year 2005. The data set contains 23 input variables:

LIMIT_BAL: Amount of given credit (NT dollar)
SEX: 1 = male; 2 = female
EDUCATION: 1 = graduate school; 2 = university; 3 = high school; 4 = others
MARRIAGE: 1 = married; 2 = single; 3 = others
AGE: Age in years
PAY_0, PAY_2 - PAY_6: History of past payment;
PAY_0 = the repayment status in September, 2005;
PAY_2 = the repayment status in August, 2005; ...; PAY_6 = the repayment status in April, 2005.
The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; ...; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
BILL_AMT1 - BILL_AMT6: Amount of bill statement (NT dollar).
BILL_AMT1 = amount of bill statement in September, 2005; BILL_AMT2 = amount of bill statement in August, 2005; ...; BILL_AMT6 = amount of bill statement in April, 2005.
PAY_AMT1 - PAY_AMT6: Amount of previous payment (NT dollar).
PAY_AMT1 = amount paid in September, 2005;
PAY_AMT2 = amount paid in August, 2005; ...; PAY_AMT6 = amount paid in April, 2005.

default payment next month: Target Variable - whether or not a customer defaulted on their credit card bill in late 2005. Default (1) and Non-Default (0)

Data Preparation and Feature Engineering

In [534]:
# Load the UCI credit-card default workbook (the first row is a spanning
# header, hence skiprows=1) and give the target column a friendlier name.
data = (
    pd.read_excel('default of credit card clients.xls', skiprows=1)
      .rename(columns={'default payment next month': 'DEFAULT_NEXT_MONTH'})
)
In [535]:
data.head(2)
Out[535]:
ID LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6 DEFAULT_NEXT_MONTH
0 1 20000 2 2 1 24 2 2 -1 -1 -2 -2 3913 3102 689 0 0 0 0 689 0 0 0 0 1
1 2 120000 2 2 2 26 -1 2 0 0 0 2 2682 1725 2682 3272 3455 3261 0 1000 1000 1000 0 2000 1
In [536]:
# Clean the repayment-status codes: the raw scale uses -2/-1/0 for
# "paid duly / no delay", so collapse everything <= 0 to 0 and keep the
# positive codes as a clean "months of delay" ordinal feature.
for var in ['PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6']:
    data[var] = data[var].clip(lower=0)

# Calculate average monthly payment and its log
# (np.log1p is the vectorized equivalent of log(x+1) and handles x == 0).
pay_amt_cols = [col for col in data.columns if col.startswith('PAY_AMT')]
data['PAY_AMT_AVG']     = data[pay_amt_cols].mean(axis=1)
data['PAY_AMT_AVG_log'] = np.log1p(data['PAY_AMT_AVG'])
#
# Each month's payment relative to the customer's own average payment.
# NOTE: when PAY_AMT_AVG == 0 this is 0/0 -> NaN; those NaNs are
# median-imputed in the preprocessing cell below.
for i in range(1, 7):
    data['PAY_REL_AMT_' + str(i)] = data['PAY_AMT' + str(i)] / data['PAY_AMT_AVG']
#
# log of payments
for i in range(1, 7):
    data['PAY_REL_AMT_log_' + str(i)] = np.log1p(data['PAY_AMT' + str(i)])

#Feature Engineering of bill amount: average bill and its log
# (negative averages — net credit balances — are floored to 0 before the log,
#  matching the original `log(x+1) if x > 0 else 0` behaviour).
bill_amt_cols = [col for col in data.columns if col.startswith('BILL_AMT')]
data['BILL_AMT_AVG'] = data[bill_amt_cols].mean(axis=1)
data['BILL_AMT_AVG_log'] = np.log1p(data['BILL_AMT_AVG'].clip(lower=0))
#
# bill sign as a separate feature (1.0 = positive bill, 0.0 = zero or credit)
for i in range(1, 7):
    data['BILL_AMT_SIGN_' + str(i)] = (data['BILL_AMT' + str(i)] > 0).astype(float)
#
#log of credit limit, plus a banded categorical version (10k-wide bins)
data['LIMIT_BAL_log'] = np.log1p(data['LIMIT_BAL'])
data['LIMIT_BAL_CAT'] = pd.cut(data['LIMIT_BAL'], range(0, np.max(data['LIMIT_BAL']), 10000),  right=False)
#
# Recode demographic columns to readable category labels.
data['SEX'] = data['SEX'].astype('category').cat.rename_categories(['M', 'F'])
data['MARRIAGE'] = data['MARRIAGE'].astype('category').cat.rename_categories(['NA', 'MARRIED', 'SINGLE', 'OTHER'])
data['AGE_CAT'] = pd.cut(data['AGE'], range(0, 100, 10), right=False)
# Codes 0 and 4-6 are undocumented/"others" in the UCI data dictionary.
education_dict = {0:'other', 1:'graduate school', 2:'university', 3:'high school', 4:'other', 5:'other', 6:'other'}
data['EDUCATION'] = data['EDUCATION'].map(education_dict)
In [537]:
data = pd.get_dummies(data, columns=[col for col in data.columns if data[col].dtypes not in ['int64', 'float64']])
In [538]:
data.head(2)
Out[538]:
ID LIMIT_BAL AGE PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6 DEFAULT_NEXT_MONTH PAY_AMT_AVG PAY_AMT_AVG_log PAY_REL_AMT_1 PAY_REL_AMT_2 PAY_REL_AMT_3 PAY_REL_AMT_4 PAY_REL_AMT_5 PAY_REL_AMT_6 PAY_REL_AMT_log_1 PAY_REL_AMT_log_2 PAY_REL_AMT_log_3 PAY_REL_AMT_log_4 PAY_REL_AMT_log_5 PAY_REL_AMT_log_6 BILL_AMT_AVG BILL_AMT_AVG_log BILL_AMT_SIGN_1 BILL_AMT_SIGN_2 BILL_AMT_SIGN_3 BILL_AMT_SIGN_4 BILL_AMT_SIGN_5 BILL_AMT_SIGN_6 LIMIT_BAL_log SEX_M SEX_F EDUCATION_graduate school EDUCATION_high school EDUCATION_other EDUCATION_university MARRIAGE_NA MARRIAGE_MARRIED MARRIAGE_SINGLE MARRIAGE_OTHER LIMIT_BAL_CAT_[0, 10000) LIMIT_BAL_CAT_[10000, 20000) LIMIT_BAL_CAT_[20000, 30000) LIMIT_BAL_CAT_[30000, 40000) LIMIT_BAL_CAT_[40000, 50000) LIMIT_BAL_CAT_[50000, 60000) LIMIT_BAL_CAT_[60000, 70000) LIMIT_BAL_CAT_[70000, 80000) LIMIT_BAL_CAT_[80000, 90000) LIMIT_BAL_CAT_[90000, 100000) LIMIT_BAL_CAT_[100000, 110000) LIMIT_BAL_CAT_[110000, 120000) LIMIT_BAL_CAT_[120000, 130000) LIMIT_BAL_CAT_[130000, 140000) LIMIT_BAL_CAT_[140000, 150000) LIMIT_BAL_CAT_[150000, 160000) LIMIT_BAL_CAT_[160000, 170000) LIMIT_BAL_CAT_[170000, 180000) LIMIT_BAL_CAT_[180000, 190000) LIMIT_BAL_CAT_[190000, 200000) LIMIT_BAL_CAT_[200000, 210000) LIMIT_BAL_CAT_[210000, 220000) LIMIT_BAL_CAT_[220000, 230000) LIMIT_BAL_CAT_[230000, 240000) LIMIT_BAL_CAT_[240000, 250000) LIMIT_BAL_CAT_[250000, 260000) LIMIT_BAL_CAT_[260000, 270000) LIMIT_BAL_CAT_[270000, 280000) LIMIT_BAL_CAT_[280000, 290000) LIMIT_BAL_CAT_[290000, 300000) LIMIT_BAL_CAT_[300000, 310000) LIMIT_BAL_CAT_[310000, 320000) LIMIT_BAL_CAT_[320000, 330000) LIMIT_BAL_CAT_[330000, 340000) LIMIT_BAL_CAT_[340000, 350000) LIMIT_BAL_CAT_[350000, 360000) LIMIT_BAL_CAT_[360000, 370000) LIMIT_BAL_CAT_[370000, 380000) LIMIT_BAL_CAT_[380000, 390000) LIMIT_BAL_CAT_[390000, 400000) LIMIT_BAL_CAT_[400000, 410000) LIMIT_BAL_CAT_[410000, 420000) 
LIMIT_BAL_CAT_[420000, 430000) LIMIT_BAL_CAT_[430000, 440000) LIMIT_BAL_CAT_[440000, 450000) LIMIT_BAL_CAT_[450000, 460000) LIMIT_BAL_CAT_[460000, 470000) LIMIT_BAL_CAT_[470000, 480000) LIMIT_BAL_CAT_[480000, 490000) LIMIT_BAL_CAT_[490000, 500000) LIMIT_BAL_CAT_[500000, 510000) LIMIT_BAL_CAT_[510000, 520000) LIMIT_BAL_CAT_[520000, 530000) LIMIT_BAL_CAT_[530000, 540000) LIMIT_BAL_CAT_[540000, 550000) LIMIT_BAL_CAT_[550000, 560000) LIMIT_BAL_CAT_[560000, 570000) LIMIT_BAL_CAT_[570000, 580000) LIMIT_BAL_CAT_[580000, 590000) LIMIT_BAL_CAT_[590000, 600000) LIMIT_BAL_CAT_[600000, 610000) LIMIT_BAL_CAT_[610000, 620000) LIMIT_BAL_CAT_[620000, 630000) LIMIT_BAL_CAT_[630000, 640000) LIMIT_BAL_CAT_[640000, 650000) LIMIT_BAL_CAT_[650000, 660000) LIMIT_BAL_CAT_[660000, 670000) LIMIT_BAL_CAT_[670000, 680000) LIMIT_BAL_CAT_[680000, 690000) LIMIT_BAL_CAT_[690000, 700000) LIMIT_BAL_CAT_[700000, 710000) LIMIT_BAL_CAT_[710000, 720000) LIMIT_BAL_CAT_[720000, 730000) LIMIT_BAL_CAT_[730000, 740000) LIMIT_BAL_CAT_[740000, 750000) LIMIT_BAL_CAT_[750000, 760000) LIMIT_BAL_CAT_[760000, 770000) LIMIT_BAL_CAT_[770000, 780000) LIMIT_BAL_CAT_[780000, 790000) LIMIT_BAL_CAT_[790000, 800000) LIMIT_BAL_CAT_[800000, 810000) LIMIT_BAL_CAT_[810000, 820000) LIMIT_BAL_CAT_[820000, 830000) LIMIT_BAL_CAT_[830000, 840000) LIMIT_BAL_CAT_[840000, 850000) LIMIT_BAL_CAT_[850000, 860000) LIMIT_BAL_CAT_[860000, 870000) LIMIT_BAL_CAT_[870000, 880000) LIMIT_BAL_CAT_[880000, 890000) LIMIT_BAL_CAT_[890000, 900000) LIMIT_BAL_CAT_[900000, 910000) LIMIT_BAL_CAT_[910000, 920000) LIMIT_BAL_CAT_[920000, 930000) LIMIT_BAL_CAT_[930000, 940000) LIMIT_BAL_CAT_[940000, 950000) LIMIT_BAL_CAT_[950000, 960000) LIMIT_BAL_CAT_[960000, 970000) LIMIT_BAL_CAT_[970000, 980000) LIMIT_BAL_CAT_[980000, 990000) AGE_CAT_[0, 10) AGE_CAT_[10, 20) AGE_CAT_[20, 30) AGE_CAT_[30, 40) AGE_CAT_[40, 50) AGE_CAT_[50, 60) AGE_CAT_[60, 70) AGE_CAT_[70, 80) AGE_CAT_[80, 90)
0 1 20000 24 2 2 0 0 0 0 3913 3102 689 0 0 0 0 689 0 0 0 0 1 114.833333 4.752152 0.0 6.0 0.0 0.0 0.0 0.0 0.0 6.536692 0.000000 0.000000 0.0 0.000000 1284.000000 7.158514 1.0 1.0 1.0 0.0 0.0 0.0 9.903538 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
1 2 120000 26 0 2 0 0 0 2 2682 1725 2682 3272 3455 3261 0 1000 1000 1000 0 2000 1 833.333333 6.726633 0.0 1.2 1.2 1.2 0.0 2.4 0.0 6.908755 6.908755 6.908755 0.0 7.601402 2846.166667 7.954080 1.0 1.0 1.0 1.0 1.0 1.0 11.695255 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
In [539]:
#Imputing the missing values  
#Imputing the missing values with each column's median
# (NaNs here come from the PAY_REL_AMT_* ratios where PAY_AMT_AVG == 0)
data = data.fillna(data.median())
#
# Split target from features; ID is an identifier, not a predictor.
y = data['DEFAULT_NEXT_MONTH']
X = data.drop(['DEFAULT_NEXT_MONTH', 'ID'], axis = 1)
#
# Univariate selection: keep the 30 features with the highest ANOVA F-score
# against the target (k passed by keyword — the positional form is deprecated).
selector = SelectKBest(f_classif, k=30)
selector.fit(X, y)
top_indices = np.nan_to_num(selector.scores_).argsort()[-30:][::-1]
X_prep = X[X.columns[top_indices]]
# Scale every selected feature to [0, 1].
scaler = preprocessing.MinMaxScaler()
X_transformed = scaler.fit_transform(X_prep)
X_transformed_df = pd.DataFrame(X_transformed, columns=X.columns[top_indices])
# Hold out 20% as a test set (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(X_transformed_df, y, test_size=0.2, random_state=1)
In [540]:
X_train.head(2)
Out[540]:
PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 LIMIT_BAL_log PAY_REL_AMT_log_1 LIMIT_BAL PAY_AMT_AVG_log PAY_REL_AMT_log_2 PAY_REL_AMT_log_3 PAY_REL_AMT_log_4 PAY_REL_AMT_log_6 PAY_REL_AMT_log_5 PAY_AMT_AVG LIMIT_BAL_CAT_[20000, 30000) LIMIT_BAL_CAT_[30000, 40000) PAY_AMT1 PAY_AMT2 PAY_AMT4 PAY_AMT3 LIMIT_BAL_CAT_[10000, 20000) PAY_AMT5 PAY_AMT6 EDUCATION_graduate school BILL_AMT_AVG_log EDUCATION_other LIMIT_BAL_CAT_[500000, 510000) SEX_M
28004 0.125 0.25 0.0 0.0 0.0 0.0 0.150507 0.000000 0.010101 0.647476 0.617825 0.504077 0.738600 0.524260 0.671094 0.009040 1.0 0.0 0.000000 0.004172 0.030596 0.001116 0.0 0.014067 0.001892 0.0 0.713694 0.0 0.0 0.0
8560 0.000 0.00 0.0 0.0 0.0 0.0 0.349475 0.412669 0.040404 0.431403 0.408792 0.420182 0.431731 0.437005 0.444242 0.000504 0.0 0.0 0.000323 0.000208 0.000509 0.000353 0.0 0.000741 0.000598 0.0 0.419512 0.0 0.0 0.0

Open Source Techniques for Interpretable Machine Learning

1. Permutation Feature Importance (PFI)

(Global Feature Importance Method)

PFI is an algorithm that computes importance scores for each of the feature variables of the dataset. The importance measures are determined by computing the sensitivity of a model to random permutations of feature values.

In other words, an importance score quantifies the contribution of a certain feature to the predictive performance of a model, in terms of how much a chosen evaluation metric deviates after permuting the values of that feature.

Intuition behind a permutation importance is that a feature is “important” if altering or permuting its values increases the model error, because in this case the model relied on the feature for the prediction. A feature is “unimportant” if altering or permuting or shuffling its values leaves the model error unchanged, because in this case the model ignored the feature for the prediction.

Permutation Feature Importance provides global insight into model behavior.

In [1]:
from IPython.display import Image
Image("PFI.png")
Out[1]:
In [398]:
# Train Random Forest Classifier
# Train Random Forest Classifier (seeded so the AUC and the PFI ranking
# below are reproducible across re-runs)
from sklearn.ensemble import RandomForestClassifier
rf_model = RandomForestClassifier(random_state=1).fit(X_train, y_train)
y_pred_rf = rf_model.predict_proba(X_test)
#
# Test-set ROC-AUC using the probability of the positive (default) class;
# roc_curve expects a 1-D score vector, so slice with [:, 1].
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_rf[:, 1])
print('RF AUC on Test Data', metrics.auc(fpr, tpr))
# Using ELI5 Calculate Feature Importance using Permutation Importance:
# shuffle each feature on the test set and measure the drop in ROC-AUC.
import eli5
from eli5.sklearn import PermutationImportance

perm = PermutationImportance(rf_model, scoring='roc_auc', random_state=1).fit(X_test, y_test)
eli5.show_weights(perm, top=100, feature_names=X_test.columns.tolist())
RF AUC on Test Data 0.7132271573781153
Out[398]:
Weight Feature
0.0759 ± 0.0054 PAY_0
0.0288 ± 0.0067 PAY_AMT_AVG
0.0264 ± 0.0072 PAY_AMT_AVG_log
0.0198 ± 0.0058 PAY_AMT3
0.0168 ± 0.0042 PAY_AMT1
0.0134 ± 0.0066 PAY_REL_AMT_log_1
0.0132 ± 0.0019 PAY_REL_AMT_log_2
0.0123 ± 0.0106 BILL_AMT_AVG_log
0.0116 ± 0.0046 PAY_2
0.0110 ± 0.0052 PAY_AMT4
0.0102 ± 0.0038 PAY_REL_AMT_log_4
0.0093 ± 0.0049 PAY_AMT2
0.0087 ± 0.0035 PAY_REL_AMT_log_3
0.0045 ± 0.0022 PAY_6
0.0033 ± 0.0020 PAY_3
0.0021 ± 0.0079 LIMIT_BAL
0.0020 ± 0.0014 PAY_4
0.0011 ± 0.0010 LIMIT_BAL_CAT_[20000, 30000)
0.0006 ± 0.0074 PAY_REL_AMT_log_5
0.0006 ± 0.0016 LIMIT_BAL_CAT_[30000, 40000)
0.0006 ± 0.0003 EDUCATION_other
-0.0001 ± 0.0002 LIMIT_BAL_CAT_[500000, 510000)
-0.0007 ± 0.0004 LIMIT_BAL_CAT_[10000, 20000)
-0.0008 ± 0.0030 PAY_5
-0.0014 ± 0.0024 EDUCATION_graduate school
-0.0017 ± 0.0036 PAY_REL_AMT_log_6
-0.0018 ± 0.0046 SEX_M
-0.0026 ± 0.0051 LIMIT_BAL_log
-0.0036 ± 0.0031 PAY_AMT5
-0.0038 ± 0.0045 PAY_AMT6

Interpreting Permutation Importance

The values towards the top are the most important features, and those towards the bottom matter least. The first number in each row shows how much model performance decreased with a random shuffling. The number after the ± measures how performance varied from one-reshuffling to the next.

Sometimes we see negative values for permutation importances. In those cases, the predictions on the shuffled (or noisy) data happened to be more accurate than on the real data. This happens when the feature didn't matter (it should have had an importance close to 0), but random chance caused the predictions on shuffled data to be more accurate.

2. LIME - Local Interpretable Model-Agnostic Explanations

LIME (Ribeiro et al., 2016) is an implementation of a local surrogate model and is used to explain individual predictions of black box models. Surrogate models are trained to approximate the predictions of the underlying black box model.

The idea is quite intuitive. First, forget about the training data and imagine you only have the black box model where you can input data points and get the predictions of the model. You can probe the box as often as you want. Your goal is to understand why the machine learning model made a certain prediction. LIME tests what happens to the predictions when you give variations of your data into the machine learning model. LIME generates a new dataset consisting of permuted samples and the corresponding predictions of the black box model. On this new dataset LIME then trains an interpretable model, which is weighted by the proximity of the sampled instances to the instance of interest. The interpretable model can be anything for example Lasso or a decision tree. The learned model should be a good approximation of the machine learning model predictions locally, but it does not have to be a good global approximation. This kind of accuracy is also called local fidelity.

Algorithm:

  1. Generate a fake dataset from the example we’re going to explain.

  2. Use black-box estimator to get target values for each example in a generated dataset (e.g. class probabilities).

  3. Train a new white-box estimator, using generated dataset and generated labels as training data. It means we’re trying to create an estimator which works the same as a black-box estimator, but which is easier to inspect. It doesn’t have to work well globally, but it must approximate the black-box model well in the area close to the original example.

  4. To express “area close to the original example” user must provide a distance/similarity metric for examples in a generated dataset. Then training data is weighted according to a distance from the original example - the further is example, the less it affects weights of a white-box estimator.

  5. Explain the original example through weights of this white-box estimator instead.

  6. Prediction quality of a white-box classifier shows how well it approximates the black-box classifier. If the quality is low then the explanation shouldn’t be trusted.

Example of LIME technique using LIME Package:

In [427]:
from xgboost import XGBClassifier

# Fit a gradient-boosted tree model as the black box to be explained.
xgb_model = XGBClassifier()
xgb_model.fit(X_train.values, y_train)

# Class probabilities on the held-out test set.
y_pred_xgb = xgb_model.predict_proba(X_test.values)

# ROC-AUC for the positive (default) class.
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_xgb[:,1:2])
print('XGB AUC on Test Data', metrics.auc(fpr, tpr))
XGB AUC on Test Data 0.7798818368508691
In [431]:
#all categorical columns
#all categorical columns
# NOTE(review): after MinMax scaling every selected column is float64, so this
# list is expected to be empty in practice — TODO confirm on a fresh run.
categorical_cols = [col for col in X_test.columns if X_test[col].dtypes not in ['int64', 'float64']]
#all feature names
all_cols = X_test.columns
import lime
from lime.lime_tabular import LimeTabularExplainer
# NOTE(review): per the LIME API, `categorical_names` expects a
# {column_index: list_of_value_names} dict and `categorical_features` expects
# column indices; passing a list of column names here is harmless only while
# the list is empty — verify against lime.lime_tabular docs before reusing
# this cell with real categorical columns.
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values ,class_names=[0, 1], feature_names = all_cols,
                                                   categorical_names=categorical_cols, mode='classification')
In [424]:
y_test[:5]
Out[424]:
10747    0
12573    1
29676    0
8856     1
21098    0
Name: DEFAULT_NEXT_MONTH, dtype: int64
In [429]:
y_pred_xgb[:,1:2][:5] #probabilities for default(1) are shown
Out[429]:
array([[0.07045205],
       [0.32993165],
       [0.10843958],
       [0.74137545],
       [0.34667933]], dtype=float32)
In [433]:
# explain_instance needs a callable that returns class probabilities; the
# model's bound predict_proba can be passed directly — no wrapper lambda.
predict_fn = xgb_model.predict_proba
# Explain the 4th test observation (index 3) with its top-10 features.
exp = explainer.explain_instance(X_test.values[3], predict_fn, num_features=10)
exp.show_in_notebook()
In [435]:
# explain_instance needs a callable that returns class probabilities; the
# model's bound predict_proba can be passed directly — no wrapper lambda.
predict_fn = xgb_model.predict_proba
# Explain the 5th test observation (index 4) with its top-10 features.
exp = explainer.explain_instance(X_test.values[4], predict_fn, num_features=10)
exp.show_in_notebook()

3. SHAP

SHAP – SHapley Additive exPlanations – explains the output of any machine learning model using Shapley values. Shapley values have been introduced in game theory since 1953 but only recently they have been used in the feature importance context. SHAP belongs to the family of “additive feature attribution methods”. This means that SHAP assigns a value to each feature for each prediction, the higher the value, the larger the feature’s attribution to the specific prediction. It also means that the sum of these values should be close to the original model prediction.

SHAP unified several existing feature attribution methods such as LIME, Deep Explainer and more and it theoretically guarantees that SHAP is the only additive feature attribution method with three desirable properties:

Local accuracy: The explanations are truthfully explaining the ML model

Missingness: Missing features have no attributed impact to the model predictions

Consistency: Consistency with human intuition (more technically, consistency states that if a model changes so that some input’s contribution increases or stays the same regardless of the other inputs, that input’s attribution should not decrease)

In [402]:
import shap

(A) SHAP for Tree Based Models (Xgboost)

In [403]:
from xgboost import XGBClassifier

# Refit the XGBoost classifier used as the black-box model for the SHAP
# tree explainer below.
xgb_model = XGBClassifier()
xgb_model.fit(X_train.values, y_train)

# Predicted class probabilities on the held-out test set.
y_pred_xgb = xgb_model.predict_proba(X_test.values)

# Report test-set ROC-AUC for the positive (default) class.
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_xgb[:,1:2])
print('XGB AUC on Test Data', metrics.auc(fpr, tpr))
XGB AUC on Test Data 0.7798818368508691
In [404]:
print(y_pred_xgb[:,1:2][:5]) #First 5 predicted rows
[[0.07045205]
 [0.32993165]
 [0.10843958]
 [0.74137545]
 [0.34667933]]
In [405]:
print(y_test[:5]) #First 5 actual values
10747    0
12573    1
29676    0
8856     1
21098    0
Name: DEFAULT_NEXT_MONTH, dtype: int64
For 4th Observation Actual Outcome = 1, Predicted Probability = 0.74137545
In [406]:
# load JS visualization code to notebook
# load JS visualization code to notebook
shap.initjs()
# TreeExplainer computes exact SHAP values for tree ensembles (fast path);
# `explainer` and `shap_values` are reused by several cells below.
explainer = shap.TreeExplainer(xgb_model)
shap_values = explainer.shap_values(X_test.values)

# visualize the explanation for test observation index 3 (the 4th row):
# expected_value is the model's base value; each feature pushes from there
shap.force_plot(explainer.expected_value, shap_values[3,:], X_test.iloc[3,:])
Out[406]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [407]:
# SHAP shows log of odds in XGBOOST, so converting back to probability
1/(1+np.exp(-(explainer.expected_value + sum(shap_values[3,:]))))
Out[407]:
0.7413753860815816
In [408]:
explainer.expected_value, shap_values[3,:], X_test.iloc[3,:]
Out[408]:
(-1.3557515,
 array([ 1.4819235e+00,  3.2297644e-01,  9.1985993e-02,  6.3817576e-02,
         1.5514079e-01,  2.0185958e-01,  4.6059433e-03,  3.1939805e-02,
         0.0000000e+00,  5.1442567e-02,  2.4539832e-02, -1.7242026e-02,
         8.4280953e-02, -4.9721816e-04, -1.5902000e-02,  0.0000000e+00,
         1.0425673e-03,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00, -1.6113481e-02, -6.4004011e-02,  8.0255521e-03,
         0.0000000e+00, -9.4135769e-04], dtype=float32),
 PAY_0                             0.250000
 PAY_2                             0.250000
 PAY_3                             0.375000
 PAY_4                             0.375000
 PAY_5                             0.375000
 PAY_6                             0.250000
 LIMIT_BAL_log                     0.451536
 PAY_REL_AMT_log_1                 0.600597
 LIMIT_BAL                         0.070707
 PAY_AMT_AVG_log                   0.556520
 PAY_REL_AMT_log_2                 0.514645
 PAY_REL_AMT_log_3                 0.538343
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.559897
 PAY_REL_AMT_log_5                 0.569169
 PAY_AMT_AVG                       0.002683
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.004236
 PAY_AMT2                          0.000950
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.001786
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.003751
 PAY_AMT6                          0.003026
 EDUCATION_graduate school         0.000000
 BILL_AMT_AVG_log                  0.773543
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             0.000000
 Name: 8856, dtype: float64)
For 5th Observation Actual Outcome = 0, Predicted Probability = 0.34667933
In [409]:
# visualize the first prediction's explanation
shap.force_plot(explainer.expected_value, shap_values[4,:], X_test.iloc[4,:])
Out[409]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [410]:
# SHAP shows log of odds in XGBOOST, so converting back to probability
1/(1+np.exp(-(explainer.expected_value + sum(shap_values[4,:]))))
Out[410]:
0.3466792325037667
In [411]:
explainer.expected_value, shap_values[4,:], X_test.iloc[4,:]
Out[411]:
(-1.3557515,
 array([ 2.0399170e-01,  5.8446687e-01, -5.0103180e-02, -1.5877679e-02,
         2.1611795e-01, -7.7224553e-02, -1.3069540e-01,  9.4784237e-02,
         0.0000000e+00,  1.8881747e-02,  1.1988448e-04, -5.6925524e-02,
         4.1078139e-02, -1.2233254e-03, -1.7266940e-02,  0.0000000e+00,
         1.0425673e-03,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00, -1.2633407e-03, -1.1938670e-01,  1.1060676e-02,
         0.0000000e+00,  2.0506112e-02], dtype=float32),
 PAY_0                             0.125000
 PAY_2                             0.250000
 PAY_3                             0.000000
 PAY_4                             0.000000
 PAY_5                             0.250000
 PAY_6                             0.000000
 LIMIT_BAL_log                     0.715676
 PAY_REL_AMT_log_1                 0.000000
 LIMIT_BAL                         0.262626
 PAY_AMT_AVG_log                   0.555775
 PAY_REL_AMT_log_2                 0.578531
 PAY_REL_AMT_log_3                 0.554614
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.576820
 PAY_REL_AMT_log_5                 0.586372
 PAY_AMT_AVG                       0.002657
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.000000
 PAY_AMT2                          0.002375
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.002232
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.004689
 PAY_AMT6                          0.003783
 EDUCATION_graduate school         1.000000
 BILL_AMT_AVG_log                  0.720214
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             1.000000
 Name: 21098, dtype: float64)

Global Feature Importance using SHAP

To get an overview of which features are most important for a model we can plot the SHAP values of every feature. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low).

In [413]:
# summarize the effects of all the features
# summarize the effects of all the features: SHAP values over the full
# training set; the beeswarm sorts features by total |SHAP| magnitude and
# colors each point by the feature's value (red high, blue low)
shap_values = explainer.shap_values(X_train.values)
shap.summary_plot(shap_values, X_train)

Mean absolute value of the SHAP values for each feature

In [414]:
shap.summary_plot(shap_values, X_train, plot_type="bar")

(B) SHAP with DeepExplainer for Neural Network based models (Keras and tensorflow)

In [415]:
from keras.models import Sequential
# `keras.layers.core` is a legacy import path — Dense lives in `keras.layers`.
# (Activation and Dropout were imported previously but never used.)
from keras.layers import Dense
input_dim = X_train.shape[1]

# Small feed-forward net: 30 -> 15 -> 1, sigmoid output for the binary
# probability of default.
nn_model = Sequential()
nn_model.add(Dense(30, input_shape=(input_dim,), activation='relu'))
nn_model.add(Dense(15, activation='relu'))
nn_model.add(Dense(1,  activation='sigmoid'))

nn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
#
# Short training run; the last 20% of the training split is held out for
# validation at each epoch.
history = nn_model.fit(np.array(X_train), np.array(y_train),
                       batch_size=25, epochs=5, verbose=2, 
                       validation_split=0.2)
Train on 19200 samples, validate on 4800 samples
Epoch 1/5
 - 2s - loss: 0.4746 - acc: 0.7890 - val_loss: 0.4287 - val_acc: 0.8187
Epoch 2/5
 - 1s - loss: 0.4436 - acc: 0.8091 - val_loss: 0.4259 - val_acc: 0.8221
Epoch 3/5
 - 1s - loss: 0.4407 - acc: 0.8127 - val_loss: 0.4242 - val_acc: 0.8225
Epoch 4/5
 - 1s - loss: 0.4393 - acc: 0.8137 - val_loss: 0.4241 - val_acc: 0.8260
Epoch 5/5
 - 1s - loss: 0.4386 - acc: 0.8176 - val_loss: 0.4268 - val_acc: 0.8269
In [416]:
# Evaluate on the held-out test set: score is [loss, accuracy].
score = nn_model.evaluate(np.array(X_test), np.array(y_test), verbose=0)
print('Test log loss:', score[0])
print('Test accuracy:', score[1])
# Predicted default probabilities; [:,0] flattens the (n, 1) output column.
nn_y_pred = nn_model.predict_on_batch(np.array(X_test))[:,0]
fpr, tpr, thresholds = metrics.roc_curve(y_test, nn_y_pred)
print('Neural Network AUC', metrics.auc(fpr, tpr))
Test log loss: 0.4369474064509074
Test accuracy: 0.8178333333333333
Neural Network AUC 0.7746758284757662
In [417]:
print(nn_y_pred[:5]) #First 5 predicted values
[0.12858841 0.345711   0.15886211 0.7064725  0.4697573 ]
In [418]:
print(y_test[:5]) #First 5 actual values
10747    0
12573    1
29676    0
8856     1
21098    0
Name: DEFAULT_NEXT_MONTH, dtype: int64
In [419]:
# select a set of background examples to take an expectation over
# select a set of background examples to take an expectation over
# (seeded RandomState — matching random_state=1 used elsewhere — so the
# background sample, and hence the SHAP baseline, is reproducible)
rng = np.random.RandomState(1)
background = X_train.values[rng.choice(X_train.shape[0], 100, replace=False)]

# explain predictions of the model
e = shap.DeepExplainer(nn_model, background)
For 4th Observation Actual Outcome = 1, Predicted Probability = 0.6936611
In [420]:
# load JS visualization code to notebook
shap.initjs()
shap.force_plot(e.expected_value, e.shap_values(X_test.values[3].reshape(1,-1))[0], X_test.iloc[3,:])
Out[420]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [421]:
e.expected_value, e.shap_values(X_test.values[3].reshape(1,-1))[0], X_test.iloc[3,:]
Out[421]:
(array([0.23331009], dtype=float32),
 array([[ 2.22899709e-01,  2.99180997e-02,  4.24762397e-02,
          4.76480403e-02,  7.27071978e-02,  5.99223470e-02,
          3.02256300e-03,  9.99260851e-03,  4.71108943e-03,
          3.32835405e-03,  5.77858314e-03,  5.94157816e-03,
         -1.67960276e-02, -2.19716104e-03,  2.72369395e-03,
         -2.14481622e-04,  5.02464647e-04,  1.44090676e-03,
          8.09926272e-05,  2.55247422e-04,  3.71933474e-04,
         -6.11045646e-04, -1.90646797e-04,  1.22682244e-04,
         -4.44647880e-04, -1.61280814e-02,  6.56499567e-04,
          0.00000000e+00, -7.52502307e-04, -4.00383859e-03]]),
 PAY_0                             0.250000
 PAY_2                             0.250000
 PAY_3                             0.375000
 PAY_4                             0.375000
 PAY_5                             0.375000
 PAY_6                             0.250000
 LIMIT_BAL_log                     0.451536
 PAY_REL_AMT_log_1                 0.600597
 LIMIT_BAL                         0.070707
 PAY_AMT_AVG_log                   0.556520
 PAY_REL_AMT_log_2                 0.514645
 PAY_REL_AMT_log_3                 0.538343
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.559897
 PAY_REL_AMT_log_5                 0.569169
 PAY_AMT_AVG                       0.002683
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.004236
 PAY_AMT2                          0.000950
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.001786
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.003751
 PAY_AMT6                          0.003026
 EDUCATION_graduate school         0.000000
 BILL_AMT_AVG_log                  0.773543
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             0.000000
 Name: 8856, dtype: float64)
For 5th Observation Actual Outcome = 0, Predicted Probability = 0.4703175
In [422]:
# load JS visualization code to notebook
shap.initjs()
shap.force_plot(e.expected_value, e.shap_values(X_test.values[4].reshape(1,-1))[0], X_test.iloc[4,:])
Out[422]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [423]:
e.expected_value, e.shap_values(X_test.values[4].reshape(1,-1))[0], X_test.iloc[4,:]
Out[423]:
(array([0.23331009], dtype=float32),
 array([[ 1.30873187e-01,  7.39333159e-02, -7.27364056e-03,
         -7.54562923e-03,  6.96027730e-02, -7.92923316e-03,
         -1.05606896e-02, -1.01188241e-02, -1.02823995e-02,
          3.21125447e-03,  3.81495325e-03, -2.83457449e-04,
          3.82677500e-03,  2.00560337e-03,  9.99903808e-04,
         -8.58655626e-05,  1.27160847e-04,  1.20974001e-03,
          8.41692255e-04,  1.82270570e-04,  4.48698995e-04,
         -5.94940406e-04,  3.48372012e-05,  1.26008005e-04,
         -6.84965738e-04, -1.06724219e-02, -1.63723552e-03,
          0.00000000e+00, -5.72409965e-04,  1.34507130e-02]]),
 PAY_0                             0.125000
 PAY_2                             0.250000
 PAY_3                             0.000000
 PAY_4                             0.000000
 PAY_5                             0.250000
 PAY_6                             0.000000
 LIMIT_BAL_log                     0.715676
 PAY_REL_AMT_log_1                 0.000000
 LIMIT_BAL                         0.262626
 PAY_AMT_AVG_log                   0.555775
 PAY_REL_AMT_log_2                 0.578531
 PAY_REL_AMT_log_3                 0.554614
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.576820
 PAY_REL_AMT_log_5                 0.586372
 PAY_AMT_AVG                       0.002657
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.000000
 PAY_AMT2                          0.002375
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.002232
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.004689
 PAY_AMT6                          0.003783
 EDUCATION_graduate school         1.000000
 BILL_AMT_AVG_log                  0.720214
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             1.000000
 Name: 21098, dtype: float64)

(C) SHAP Model agnostic with KernelExplainer (explains any function)

In [436]:
from sklearn import linear_model

# Fit a logistic regression classifier on the training split
model_logistic = linear_model.LogisticRegression()
model_logistic.fit(X_train, y_train)

# Predicted class probabilities on the held-out test split
y_pred_regr = model_logistic.predict_proba(X_test.values)

# Evaluate discrimination via the ROC curve on the positive-class probabilities
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_regr[:, 1:2])
auc_test = metrics.auc(fpr, tpr)
print('Logistic Regression AUC on Test Data', auc_test)
Logistic Regression AUC on Test Data 0.7730130464191519
In [ ]:
# Use Kernel SHAP (model-agnostic) to explain test-set predictions.
# Per the SHAP documentation, when the training set is large we summarize it
# with k-means cluster centers (here 10) rather than passing the whole set,
# so the explainer's expected values are estimated from representative points.
X_train_summary = shap.kmeans(X_train, 10)
explainer = shap.KernelExplainer(model_logistic.predict_proba, X_train_summary)
# nsamples controls how many perturbed samples are used per explanation
shap_values = explainer.shap_values(X_test, nsamples=100)

# Force plot of the class-0 SHAP values for the first test instance
shap.force_plot(explainer.expected_value[0], shap_values[0][0,:], X_test.iloc[0,:])
In [438]:
# Base (expected) value of the explainer for class 1 — the starting point
# that per-feature SHAP values are added to (see additivity check below)
explainer.expected_value[1]
Out[438]:
0.19040174988950823
In [439]:
 # Actual target values for the first five test observations
y_test[:5]
Out[439]:
10747    0
12573    1
29676    0
8856     1
21098    0
Name: DEFAULT_NEXT_MONTH, dtype: int64
In [440]:
# Predicted probabilities (class 1) from logistic regression for the first five test observations
y_pred_regr[:5][:,1:2]
Out[440]:
array([[0.10854273],
       [0.33027806],
       [0.14359415],
       [0.76505556],
       [0.35220043]])
For 4th Observation Actual Outcome = 1, Predicted Probability = 0.76505556
In [441]:
# Force plot of the class-1 SHAP values for the 4th test observation (index 3)
shap.force_plot(explainer.expected_value[1], shap_values[1][3,:], X_test.iloc[3,:])
Out[441]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [442]:
# Additivity check for observation index 3: base value + sum of SHAP values
# should reproduce the model's predicted probability (0.76505556 above)
explainer.expected_value[1] + sum(shap_values[1][3,:])
Out[442]:
0.7650555635365744
In [443]:
# Inspect base value, per-feature SHAP values, and feature values for observation index 3
explainer.expected_value[1], shap_values[1][3,:], X_test.iloc[3,:]
Out[443]:
(0.19040174988950823,
 array([0.29894122, 0.04073524, 0.05162284, 0.03162683, 0.06781257,
        0.06170632, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.0222088 , 0.        , 0.        , 0.        ]),
 PAY_0                             0.250000
 PAY_2                             0.250000
 PAY_3                             0.375000
 PAY_4                             0.375000
 PAY_5                             0.375000
 PAY_6                             0.250000
 LIMIT_BAL_log                     0.451536
 PAY_REL_AMT_log_1                 0.600597
 LIMIT_BAL                         0.070707
 PAY_AMT_AVG_log                   0.556520
 PAY_REL_AMT_log_2                 0.514645
 PAY_REL_AMT_log_3                 0.538343
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.559897
 PAY_REL_AMT_log_5                 0.569169
 PAY_AMT_AVG                       0.002683
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.004236
 PAY_AMT2                          0.000950
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.001786
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.003751
 PAY_AMT6                          0.003026
 EDUCATION_graduate school         0.000000
 BILL_AMT_AVG_log                  0.773543
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             0.000000
 Name: 8856, dtype: float64)
For 5th Observation Actual Outcome = 0, Predicted Probability = 0.35220043
In [444]:
# Force plot of the class-1 SHAP values for the 5th test observation (index 4)
shap.force_plot(explainer.expected_value[1], shap_values[1][4,:], X_test.iloc[4,:])
Out[444]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [445]:
# Inspect base value, per-feature SHAP values, and feature values for observation index 4
explainer.expected_value[1], shap_values[1][4,:], X_test.iloc[4,:]
Out[445]:
(0.19040174988950823,
 array([ 9.37841930e-02,  3.98830052e-02,  0.00000000e+00,  0.00000000e+00,
         3.47363460e-02,  0.00000000e+00, -3.09911874e-02,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         2.44570880e-02,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -7.07629904e-05,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00]),
 PAY_0                             0.125000
 PAY_2                             0.250000
 PAY_3                             0.000000
 PAY_4                             0.000000
 PAY_5                             0.250000
 PAY_6                             0.000000
 LIMIT_BAL_log                     0.715676
 PAY_REL_AMT_log_1                 0.000000
 LIMIT_BAL                         0.262626
 PAY_AMT_AVG_log                   0.555775
 PAY_REL_AMT_log_2                 0.578531
 PAY_REL_AMT_log_3                 0.554614
 PAY_REL_AMT_log_4                 0.000000
 PAY_REL_AMT_log_6                 0.576820
 PAY_REL_AMT_log_5                 0.586372
 PAY_AMT_AVG                       0.002657
 LIMIT_BAL_CAT_[20000, 30000)      0.000000
 LIMIT_BAL_CAT_[30000, 40000)      0.000000
 PAY_AMT1                          0.000000
 PAY_AMT2                          0.002375
 PAY_AMT4                          0.000000
 PAY_AMT3                          0.002232
 LIMIT_BAL_CAT_[10000, 20000)      0.000000
 PAY_AMT5                          0.004689
 PAY_AMT6                          0.003783
 EDUCATION_graduate school         1.000000
 BILL_AMT_AVG_log                  0.720214
 EDUCATION_other                   0.000000
 LIMIT_BAL_CAT_[500000, 510000)    0.000000
 SEX_M                             1.000000
 Name: 21098, dtype: float64)
In [446]:
# Additivity check for observation index 4: base value + sum of SHAP values
# should reproduce the model's predicted probability (0.35220043 above)
explainer.expected_value[1] + sum(shap_values[1][4,:])
Out[446]:
0.35220043178854565

Probabilities by logistic regression — from the fitted coefficients and intercept

In [563]:
# Fitted coefficients of the logistic regression model (one per input feature)
model_logistic.coef_ 
Out[563]:
array([[ 6.45808016,  0.91190808,  0.66395476,  0.39058927,  0.73718774,
         1.4301084 , -0.8919969 , -0.0979786 , -0.01609167, -0.61279452,
        -0.09615763, -0.30138543, -0.25291871, -0.00837795,  0.03441226,
        -0.61177212, -0.04269953, -0.05304702, -1.15223127, -0.79130129,
        -0.05231076,  0.37912682, -0.12432066, -0.46045983, -0.14052108,
        -0.04138196, -0.01060361, -1.04944079, -0.02610325,  0.08982138]])
In [564]:
# Fitted intercept of the logistic regression model
model_logistic.intercept_
Out[564]:
array([-0.6737483])
In [566]:
import math

def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)), mapping any real x into (0, 1).

    Uses the numerically stable branch for negative inputs: the naive
    ``1 / (1 + math.exp(-x))`` raises OverflowError once -x exceeds ~709
    (the float64 limit of math.exp). The rewritten negative branch
    ``e^x / (1 + e^x)`` returns the correct 0.0 instead.

    Parameters
    ----------
    x : float
        Raw score (log-odds).

    Returns
    -------
    float
        Probability in (0, 1); exactly 0.0 / 1.0 at extreme inputs.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For x < 0, exp(-x) can overflow; algebraically equivalent stable form.
    z = math.exp(x)
    return z / (1 + z)
In [577]:
# Probabilities recomputed by hand from the fitted coefficients and intercept:
# sigmoid(X . w^T + b) should match predict_proba's class-1 column above
[sigmoid(x) for x in (np.dot(X_test[:5].values, model_logistic.coef_.T ) + model_logistic.intercept_).reshape(1,-1)[0]]
Out[577]:
[0.10854272966996846, 0.33027806407643306, 0.14359415119158594, 0.7650555635365747, 0.3522004317885456]

4. LOCO - Leave-One-Covariate-Out

LOCO creates local interpretations for each row in a training or unlabeled score set by scoring the row of data once and then again for each input variable (e.g., covariate) in the row. In each additional scoring run, one input variable is set to missing, zero, its mean value, or another appropriate value for leaving it out of the prediction. The input variable with the largest absolute impact on the prediction for that row is taken to be the most important variable for that row’s prediction. Variables can also be ranked by their impact on the prediction on a per-row basis.

It is the simplest technique of all, but its accuracy can deteriorate when complex nonlinear dependencies exist in a model.

LOCO is the simplest of these techniques — and not necessarily the most accurate — and can be implemented in a straightforward way as shown below:

In [549]:
# Work on a copy of the first five test rows so the original X_test is untouched
X_test_5_df = X_test[:5].copy()
In [550]:
# Original (baseline) predictions for the top 5 rows in the test data
y_pred_5 = xgb_model.predict_proba(X_test_5_df.values)[:,1:2]
In [551]:
# Baseline predicted probabilities (class 1) for the five observations
y_pred_5
Out[551]:
array([[0.07045205],
       [0.32993165],
       [0.10843958],
       [0.74137545],
       [0.34667933]], dtype=float32)
In [552]:
# Empty frame that will collect the per-feature LOCO attributions (one column per feature)
LOCO_df  = pd.DataFrame(columns = X_test_5_df.columns)
In [560]:
# LOCO attribution for the top 5 test observations: re-score once per feature
# with that single feature blanked out, and record the shift in prediction.
for col in X_test_5_df.columns:
    # Predict with this one covariate set to missing
    perturbed = X_test_5_df.copy()
    perturbed[col] = np.nan
    loco_preds = xgb_model.predict_proba(perturbed.values)[:,1:2]
    print('LOCO Progress: ' + col)
    # Attribution = original prediction minus the leave-one-covariate-out prediction
    LOCO_df[col] = (y_pred_5 - loco_preds).ravel()
LOCO Progress: PAY_0
LOCO Progress: PAY_2
LOCO Progress: PAY_3
LOCO Progress: PAY_4
LOCO Progress: PAY_5
LOCO Progress: PAY_6
LOCO Progress: LIMIT_BAL_log
LOCO Progress: PAY_REL_AMT_log_1
LOCO Progress: LIMIT_BAL
LOCO Progress: PAY_AMT_AVG_log
LOCO Progress: PAY_REL_AMT_log_2
LOCO Progress: PAY_REL_AMT_log_3
LOCO Progress: PAY_REL_AMT_log_4
LOCO Progress: PAY_REL_AMT_log_6
LOCO Progress: PAY_REL_AMT_log_5
LOCO Progress: PAY_AMT_AVG
LOCO Progress: LIMIT_BAL_CAT_[20000, 30000)
LOCO Progress: LIMIT_BAL_CAT_[30000, 40000)
LOCO Progress: PAY_AMT1
LOCO Progress: PAY_AMT2
LOCO Progress: PAY_AMT4
LOCO Progress: PAY_AMT3
LOCO Progress: LIMIT_BAL_CAT_[10000, 20000)
LOCO Progress: PAY_AMT5
LOCO Progress: PAY_AMT6
LOCO Progress: EDUCATION_graduate school
LOCO Progress: BILL_AMT_AVG_log
LOCO Progress: EDUCATION_other
LOCO Progress: LIMIT_BAL_CAT_[500000, 510000)
LOCO Progress: SEX_M
In [561]:
# LOCO attributions: baseline prediction minus prediction with each feature set to NaN
LOCO_df
Out[561]:
PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 LIMIT_BAL_log PAY_REL_AMT_log_1 LIMIT_BAL PAY_AMT_AVG_log PAY_REL_AMT_log_2 PAY_REL_AMT_log_3 PAY_REL_AMT_log_4 PAY_REL_AMT_log_6 PAY_REL_AMT_log_5 PAY_AMT_AVG LIMIT_BAL_CAT_[20000, 30000) LIMIT_BAL_CAT_[30000, 40000) PAY_AMT1 PAY_AMT2 PAY_AMT4 PAY_AMT3 LIMIT_BAL_CAT_[10000, 20000) PAY_AMT5 PAY_AMT6 EDUCATION_graduate school BILL_AMT_AVG_log EDUCATION_other LIMIT_BAL_CAT_[500000, 510000) SEX_M
0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 -0.071373 -0.030835 0.0 -0.029920 -0.004029 -0.025591 -0.006316 -0.022704 -0.004304 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 -0.013381 0.0 0.0 0.007895
1 0.000000 0.000000 0.000000 0.000000 0.078735 0.012601 0.000000 0.000000 0.0 0.001610 0.000000 -0.013590 0.000000 0.000000 -0.009702 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 0.049715 0.0 0.0 0.000000
2 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 -0.016417 0.000000 0.0 -0.007159 -0.051122 0.005328 -0.001451 0.012332 0.014991 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 -0.030802 0.0 0.0 0.003683
3 0.259511 0.002420 -0.024566 -0.030968 0.043300 0.057227 -0.047013 0.002943 0.0 0.008993 0.000000 -0.003218 0.000000 0.001320 -0.017954 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 0.216822 0.0 0.0 0.000000
4 0.097337 0.133826 0.000000 0.000000 0.063681 0.000000 -0.056171 0.000000 0.0 -0.093801 0.000000 -0.048913 0.000000 0.001555 -0.023171 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.004788 -0.021493 0.0 0.0 0.003639
In [562]:
# The five test rows that the LOCO attributions above correspond to
X_test_5_df 
Out[562]:
PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 LIMIT_BAL_log PAY_REL_AMT_log_1 LIMIT_BAL PAY_AMT_AVG_log PAY_REL_AMT_log_2 PAY_REL_AMT_log_3 PAY_REL_AMT_log_4 PAY_REL_AMT_log_6 PAY_REL_AMT_log_5 PAY_AMT_AVG LIMIT_BAL_CAT_[20000, 30000) LIMIT_BAL_CAT_[30000, 40000) PAY_AMT1 PAY_AMT2 PAY_AMT4 PAY_AMT3 LIMIT_BAL_CAT_[10000, 20000) PAY_AMT5 PAY_AMT6 EDUCATION_graduate school BILL_AMT_AVG_log EDUCATION_other LIMIT_BAL_CAT_[500000, 510000) SEX_M
10747 0.000 0.00 0.000 0.000 0.000 0.00 0.745676 0.659599 0.303030 0.622246 0.606806 0.605170 0.600243 0.576820 0.532942 0.006454 0.0 0.0 0.009496 0.003562 0.004831 0.004464 0.0 0.002345 0.003783 0.0 0.840306 0.0 0.0 1.0
12573 0.000 0.00 0.000 0.000 0.250 0.25 0.000000 0.000000 0.000000 0.465938 0.000000 0.579880 0.000000 0.000000 0.401860 0.000800 0.0 0.0 0.000000 0.000000 0.000000 0.003156 1.0 0.000427 0.000000 0.0 0.446160 0.0 0.0 0.0
29676 0.000 0.00 0.000 0.000 0.000 0.00 0.349475 0.000000 0.040404 0.724483 0.749258 0.561565 0.537585 0.549009 0.824756 0.025273 0.0 0.0 0.000000 0.027464 0.002093 0.002455 0.0 0.103128 0.002622 0.0 0.733884 0.0 0.0 1.0
8856 0.250 0.25 0.375 0.375 0.375 0.25 0.451536 0.600597 0.070707 0.556520 0.514645 0.538343 0.000000 0.559897 0.569169 0.002683 0.0 0.0 0.004236 0.000950 0.000000 0.001786 0.0 0.003751 0.003026 0.0 0.773543 0.0 0.0 0.0
21098 0.125 0.25 0.000 0.000 0.250 0.00 0.715676 0.000000 0.262626 0.555775 0.578531 0.554614 0.000000 0.576820 0.586372 0.002657 0.0 0.0 0.000000 0.002375 0.000000 0.002232 0.0 0.004689 0.003783 1.0 0.720214 0.0 0.0 1.0